Univalued binary tree [DFS]

Time: O(N); Space: O(H), where N is the number of nodes and H is the height of the tree; easy

A binary tree is univalued if every node in the tree has the same value.

Return true if and only if the given tree is univalued.

Example 1:

Input: root = {TreeNode} [1,1,1,1,1,null,1]

Output: True

Example 2:

Input: root = {TreeNode} [2,2,2,5,2]

Output: False
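The examples use LeetCode's level-order notation, where null marks a missing child. The helper build_tree below is hypothetical (it is not part of the original notebook): a minimal sketch, assuming Python 3, that turns such a list into the TreeNode structure used throughout, with Python None standing in for null.

```python
from collections import deque
from typing import List, Optional


class TreeNode(object):
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Hypothetical helper: build a tree from a level-order list (None = missing child)."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])  # nodes still waiting for their children
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next two entries in the list are this node's left and right children.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


# Example 1: [1,1,1,1,1,null,1]
root = build_tree([1, 1, 1, 1, 1, None, 1])
# root.right has no left child and a right child with value 1.
```

The trees built this way match the ones constructed by hand in the test cells below.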

Notes:

  • The number of nodes in the given tree will be in the range [1, 100].

  • Each node’s value will be an integer in the range [0, 99].

[6]:
class TreeNode(object):
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None

1. Depth-First Search [O(N), O(H)]

Intuition and Algorithm

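Let’s collect all the values in the tree, then check that they are all equal.

To collect the values, we perform a depth-first search. Storing every value costs O(N) extra space on top of the O(H) recursion stack; Solution2 below avoids the list by comparing each node with the root’s value during an iterative DFS, so it needs only O(H) space.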
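A variant of the same idea, sketched for Python 3 and not taken from the original notebook: instead of materializing the list, a recursive generator yields values in preorder and all() compares each one with the root’s value. The names iter_values and is_unival_tree are illustrative. Because all() stops consuming the generator at the first False, the traversal ends at the first mismatch, and the only extra space is the O(H) chain of nested generators.

```python
from typing import Iterator, Optional


class TreeNode(object):
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


def iter_values(node: Optional[TreeNode]) -> Iterator[int]:
    """Yield node values in preorder: node, then left subtree, then right subtree."""
    if node:
        yield node.val
        yield from iter_values(node.left)
        yield from iter_values(node.right)


def is_unival_tree(root: TreeNode) -> bool:
    # all() short-circuits, so nodes after the first mismatch are never visited.
    return all(v == root.val for v in iter_values(root))


# Example 2: [2,2,2,5,2]
root = TreeNode(2)
root.left = TreeNode(2)
root.right = TreeNode(2)
root.left.left = TreeNode(5)
root.left.right = TreeNode(2)
print(is_unival_tree(root))  # False
```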

[7]:
class Solution1(object):
    def isUnivalTree(self, root):
        """
        :type root: TreeNode
        :rtype: bool
        """
        vals = []

        def dfs(node):
            if node:
                vals.append(node.val)
                dfs(node.left)
                dfs(node.right)

        dfs(root)
        return len(set(vals)) == 1
[8]:
s = Solution1()

root = TreeNode(1)
root.left = TreeNode(1)
root.right = TreeNode(1)
root.left.left = TreeNode(1)
root.left.right = TreeNode(1)
root.right.right = TreeNode(1)
assert s.isUnivalTree(root) == True

root = TreeNode(2)
root.left = TreeNode(2)
root.right = TreeNode(2)
root.left.left = TreeNode(5)
root.left.right = TreeNode(2)
assert s.isUnivalTree(root) == False
[9]:
class Solution2(object):
    """
    Time: O(N)
    Space: O(H)
    """
    def isUnivalTree(self, root):
        """
        :type root: TreeNode
        :rtype: bool
        """
        s = [root]
        while s:
            node = s.pop()
            if not node:
                continue
            if node.val != root.val:
                return False
            s.append(node.left)
            s.append(node.right)

        return True
[10]:
s = Solution2()

root = TreeNode(1)
root.left = TreeNode(1)
root.right = TreeNode(1)
root.left.left = TreeNode(1)
root.left.right = TreeNode(1)
root.right.right = TreeNode(1)
assert s.isUnivalTree(root) == True

root = TreeNode(2)
root.left = TreeNode(2)
root.right = TreeNode(2)
root.left.left = TreeNode(5)
root.left.right = TreeNode(2)
assert s.isUnivalTree(root) == False
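Swapping Solution2’s stack for a FIFO queue gives a breadth-first variant with the same logic. This is a different traversal order, shown as a sketch rather than as part of the original notebook; the function name is_unival_tree_bfs is illustrative. The queue holds at most the unprocessed part of one level plus the children enqueued so far, so the extra space grows with the tree’s maximum width rather than its height.

```python
from collections import deque


class TreeNode(object):
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


def is_unival_tree_bfs(root):
    """Level-order check: same as the stack version, but nodes come out in FIFO order."""
    queue = deque([root])
    while queue:
        node = queue.popleft()
        if not node:
            # Missing children are enqueued as None and simply skipped.
            continue
        if node.val != root.val:
            return False
        queue.append(node.left)
        queue.append(node.right)
    return True
```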

2. Recursion [O(N), O(H)]

Intuition and Algorithm

A tree is univalued if both of its subtrees are univalued and the root has the same value as each child that exists. A missing child imposes no constraint.

We can write the function recursively. left_correct records that the left side is correct: either there is no left child, or the left child has the same value as the root and its own subtree is univalued. right_correct records the same for the right child. The tree is univalued exactly when both are true.
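An equivalent way to phrase the recursion, sketched here as an alternative rather than the notebook’s method: since every node must equal the root, pass the root’s value down and check each node against it. The helper name matches is hypothetical. Python’s and short-circuits, so recursion stops as soon as one comparison fails.

```python
from typing import Optional


class TreeNode(object):
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


def is_unival_tree(root: TreeNode) -> bool:
    def matches(node: Optional[TreeNode], target: int) -> bool:
        # An empty subtree trivially satisfies the property.
        if node is None:
            return True
        return (node.val == target
                and matches(node.left, target)
                and matches(node.right, target))

    return matches(root, root.val)
```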

[11]:
class Solution3(object):
    def isUnivalTree(self, root):
        """
        :type root: TreeNode
        :rtype: bool
        """
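        # A missing child imposes no constraint; otherwise the child must
        # match the root's value and its own subtree must be univalued.
        left_correct = (not root.left or
                        (root.val == root.left.val and self.isUnivalTree(root.left)))
        right_correct = (not root.right or
                         (root.val == root.right.val and self.isUnivalTree(root.right)))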
        return left_correct and right_correct
[12]:
s = Solution3()

root = TreeNode(1)
root.left = TreeNode(1)
root.right = TreeNode(1)
root.left.left = TreeNode(1)
root.left.right = TreeNode(1)
root.right.right = TreeNode(1)
assert s.isUnivalTree(root) == True

root = TreeNode(2)
root.left = TreeNode(2)
root.right = TreeNode(2)
root.left.left = TreeNode(5)
root.left.right = TreeNode(2)
assert s.isUnivalTree(root) == False
[13]:
class Solution4(object):
    def isUnivalTree(self, root):
        """
        :type root: TreeNode
        :rtype: bool
        """
        return (not root.left or (root.left.val == root.val and self.isUnivalTree(root.left))) and \
               (not root.right or (root.right.val == root.val and self.isUnivalTree(root.right)))
[14]:
s = Solution4()

root = TreeNode(1)
root.left = TreeNode(1)
root.right = TreeNode(1)
root.left.left = TreeNode(1)
root.left.right = TreeNode(1)
root.right.right = TreeNode(1)
assert s.isUnivalTree(root) == True

root = TreeNode(2)
root.left = TreeNode(2)
root.right = TreeNode(2)
root.left.left = TreeNode(5)
root.left.right = TreeNode(2)
assert s.isUnivalTree(root) == False